#include "RayTracer.h"
#include "handleArgs.h"
#include "PNGImage.h"
#include "Sphere.h"
#include "Triangle.h"
#include "Ray.h"
#include "Vector3D.h"
#include <iostream>
#include "LambertianShader.h"
#include "BlinnPhongShader.h"
#include <sstream>
#include <fstream>
#include "Constants.h"
#include <pthread.h>

using namespace std;

/**
 * Builds a ray tracer from the program's command-line arguments.
 *
 * The arguments are interpreted by cs5721::GraphicsArgs; the resulting file
 * names and image dimensions are stored on this object. The output pixel
 * buffer is allocated, the scene file is read, and a BvhNode is constructed
 * over the surfaces collected while reading.
 *
 * @param argc  argument count, forwarded unchanged to GraphicsArgs::process
 * @param argv  argument vector, forwarded unchanged to GraphicsArgs::process
 */
RayTracer::RayTracer(int argc, char** argv) {
    // Let the shared argument handler interpret the command line.
    cs5721::GraphicsArgs args;
    args.process(argc, argv);

    // Keep the settings this class needs later on.
    imageHeight = args.height;
    imageWidth = args.width;
    outputFileName = args.outputFileName;
    inputFileName = args.inputFileName;

    // Three floats are stored for every pixel.
    const int pixelCount = (int) imageWidth * (int) imageHeight;
    imageData = new float[ pixelCount * 3 ];

    // Load the scene description, then build the hierarchy over its surfaces.
    readSceneFile();
    root = new BvhNode(surfaces, 0);
    //root->print();
}

void RayTracer::start() {
    cs5721::PNGImage pngimage;
    pthread_t thread1, thread2, thread3, thread4;
    scene scene1, scene2, scene3, scene4;

    scene1.imageHeight = imageHeight;
    scene1.imageWidth = imageWidth;
    scene1.cameras = cameras;
    scene1.lights = lights;
    scene1.root = root;
    scene1.imageData = imageData;
    scene1.start = 0.0;
    scene1.finish = imageHeight / 4.0;
    scene1.timer = true;

    scene2.imageHeight = imageHeight;
    scene2.imageWidth = imageWidth;
    scene2.cameras = cameras;
    scene2.lights = lights;
    scene2.root = root;
    scene2.imageData = imageData;
    scene2.start = imageHeight / 4.0;
    scene2.finish = 2.0 * imageHeight / 4.0;
    scene2.timer = false;

    scene3.imageHeight = imageHeight;
    scene3.imageWidth = imageWidth;
    scene3.cameras = cameras;
    scene3.lights = lights;
    scene3.root = root;
    scene3.imageData = imageData;
    scene3.start = 2.0 * imageHeight / 4.0;
    scene3.finish = 3.0 * imageHeight / 4.0;
    scene3.timer = false;

    scene4.imageHeight = imageHeight;
    scene4.imageWidth = imageWidth;
    scene4.cameras = cameras;
    scene4.lights = lights;
    scene4.root = root;
    scene4.imageData = imageData;
    scene4.start = 3.0 * imageHeight / 4.0;
    scene4.finish = imageHeight;
    scene4.timer = false;

    scene *scene1ptr = &scene1;
    scene *scene2ptr = &scene2;
    scene *scene3ptr = &scene3;
    scene *scene4ptr = &scene4;

    int return1 = pthread_create(&thread1, NULL, render, (void *) scene1ptr);
    int return2 = pthread_create(&thread2, NULL, render, (void *) scene2ptr);
    int return3 = pthread_create(&thread3, NULL, render, (void *) scene3ptr);
    int return4 = pthread_create(&thread4, NULL, render, (void *) scene4ptr);

    pthread_join(thread1, NULL);
    pthread_join(thread2, NULL);
    pthread_join(thread3, NULL);
    pthread_join(thread4, NULL);
    cout << "Thread 1 returns: " << return1 << endl;
    cout << "Thread 2 returns: " << return2 << endl;
    cout << "Thread 3 returns: " << return3 << endl;
    cout << "Thread 4 returns: " << return4 << endl;
    cout << "Done!" << endl << endl;

    pngimage.writeFileData(outputFileName, (int) imageWidth, (int) imageHeight, imageData);
    delete [] imageData;

}

static void *render(void *ptr) {
    double random;
    RayTracer::scene *myScene = (RayTracer::scene *) ptr;
    vector<Camera> myCameras = myScene->cameras;

    Camera cam = myCameras[0];

    double focalLength = cam.getFocalLength();
    double imagePW = cam.getImagePlaneWidth();
    double imagePH = (imagePW * myScene->imageHeight) / myScene->imageWidth;
    cs5721::Vector3D origin = cam.getPosition();
    cs5721::Vector3D U = cam.orthonormalFrame[0];
    cs5721::Vector3D V = cam.orthonormalFrame[1];
    cs5721::Vector3D W = cam.orthonormalFrame[2];

    for (double h = myScene->start; h < myScene->finish; h++) {
        for (double w = 0.0; w < myScene->imageWidth; w++) {

            cs5721::Vector3D color(0.0, 0.0, 0.0);

            if (antiAliasing) {
                for (int p = 0; p < gridSize; p++) {
                    for (int q = 0; q < gridSize; q++) {

                        random = (static_cast<double> (rand()) + 1) / (RAND_MAX + 1);
                        double u = -imagePH + (imagePH + imagePH)
                                * (h + (p + random) / gridSize) / myScene->imageHeight;

                        random = (static_cast<double> (rand()) + 1) / (RAND_MAX + 1);
                        double v = -imagePW + (imagePW + imagePW)
                                * (w + (q + random) / gridSize) / myScene->imageWidth;

                        cs5721::Vector3D direction = W*-focalLength + U * u + V*v;

                        Ray viewRay(origin, direction);

                        color += rayTrace(viewRay, 0, myScene);
                    }
                }
                color /= (double) (gridSize * gridSize);
                
            } else {
                double u = -imagePH + (imagePH + imagePH)
                        * (h + 0.5) / myScene->imageHeight;

                double v = -imagePW + (imagePW + imagePW)
                        * (w + 0.5) / myScene->imageWidth;

                cs5721::Vector3D direction = W*-focalLength + U * u + V*v;

                Ray viewRay(origin, direction);

                color += rayTrace(viewRay, 0, myScene);
            }

            color.clamp();

            int idx = (w * myScene->imageWidth * 3) + h * 3;
            myScene->imageData[idx + 0] = color[0];
            myScene->imageData[idx + 1] = color[1];
            myScene->imageData[idx + 2] = color[2];
        }
    }
}

/**
 * Computes the colour seen along a ray.
 *
 * @param ray      ray to trace
 * @param rD       reflection depth of this ray (0 for a primary ray)
 * @param myScene  scene providing the hierarchy root, lights and cameras
 * @return         accumulated colour; (0, 0, 0) when nothing is hit
 *
 * On a hit, a mirror contribution is added recursively while the shader's
 * mirror coefficient is positive and rD is below maxDepth. Then each light
 * contributes its shaded colour unless a shadow ray from the hit point is
 * blocked before reaching the light.
 */
static cs5721::Vector3D rayTrace(Ray& ray, int rD, RayTracer::scene * myScene) {
    // Refer to the scene's lists directly; copying both vectors on every
    // (recursive) call is needless work on the hottest path.
    std::vector<Camera> &cameras = myScene->cameras;
    std::vector<Light> &lights = myScene->lights;

    Record rec;
    cs5721::Vector3D color(0.0, 0.0, 0.0);

    // Nothing hit: return the background colour (black).
    if (!myScene->root->isHit(ray, eps, infinity, rec)) {
        return color;
    }

    cs5721::Vector3D pointHit = ray.getOrigin() + (ray.getDirection() * rec.t);

    // Mirror reflection, limited by maxDepth.
    if (rec.shader->getMirrorCoef() > 0.0 && rD < maxDepth) {
        double reflection = 2.0 * (ray.getDirection().dot(rec.normal));
        cs5721::Vector3D reflectDir = ray.getDirection() - (rec.normal * reflection);
        Ray reflectRay(pointHit, reflectDir);

        cs5721::Vector3D reflectColor = rayTrace(reflectRay, rD + 1, myScene);
        color += reflectColor * rec.shader->getMirrorCoef();
    }

    // Direct illumination from every light.
    for (std::size_t i = 0; i < lights.size(); i++) {
        Light curLight = lights[i];
        Record tempRecord;

        cs5721::Vector3D toLight = curLight.getPosition() - pointHit;
        cs5721::Vector3D light = toLight;
        light.normalize();

        // Dotting the unnormalised vector with its unit direction gives the
        // distance to the light.
        double lightDistance = toLight.dot(light);

        cs5721::Vector3D view = cameras[0].getPosition() - pointHit;
        view.normalize();

        Ray shadowRay(pointHit, light);

        // Only geometry between the hit point and the light may shadow it;
        // testing out to infinity lets objects behind the light block it.
        if (!myScene->root->isHit(shadowRay, eps, lightDistance, tempRecord)) {
            cs5721::Vector3D lightIntensity = curLight.getIntensity();
            cs5721::Vector3D tempColor = rec.shader->calcColor(
                    lightIntensity, light, rec.normal, view);
            color += tempColor;
        }
    }

    return color;
}

void RayTracer::readSceneFile() {
    ifstream sceneFile;
    string line;

    sceneFile.open(inputFileName.c_str());
    if (sceneFile.is_open()) {
        while (getline(sceneFile, line)) {
            if ((line.substr(0, 5).compare("shape")) == 0) {
                if ((line.substr(6, 6).compare("sphere")) == 0) {
                    processSphere(line);
                } else {
                    processTriangle(line);
                }
            } else if ((line.substr(0, 5).compare("light")) == 0) {
                processLight(line);
            } else if ((line.substr(0, 6).compare("camera")) == 0) {
                processCamera(line);
            } else {
                cout << "Error: RayTracer: readSceneFile - bad 'shape' read" << endl;
                exit(1);
            }
        }
        sceneFile.close();
    } else {
        cout << "Error: RayTracer: readSceneFile - couldnt open file" << endl;
        exit(1);
    }
}

/**
 * Parses one sphere line of the scene file and appends the sphere to surfaces.
 *
 * Whitespace-separated fields, in order:
 *   keyword, shape name, x y z, radius, shader name, diffuse r g b,
 *   then for non-lambertian shaders: specular r g b, phong exponent,
 *   and optionally the letter 'm' followed by a mirror coefficient.
 *
 * @param s  the full text of the line
 *
 * Every field up to the diffuse colour is mandatory; if any of them cannot
 * be read an error is printed and the program exits with status 1, rather
 * than constructing a sphere from indeterminate values. Specular fields that
 * are missing default to 0, and the mirror coefficient is 0 unless it is
 * introduced by 'm' and parses as a number.
 */
void RayTracer::processSphere(string s) {
    double x = 0.0, y = 0.0, z = 0.0, radius = 0.0;
    double dRed = 0.0, dGreen = 0.0, dBlue = 0.0;
    string keyword, shapeName, shaderString;
    istringstream iss(s, istringstream::in);

    iss >> keyword >> shapeName
            >> x >> y >> z >> radius
            >> shaderString
            >> dRed >> dGreen >> dBlue;

    if (!iss) {
        cout << "Error: RayTracer: processSphere - malformed sphere line: " << s << endl;
        exit(1);
    }

    Shader* shader = NULL;
    if (shaderString.compare(lambertian) == 0) {
        shader = new LambertianShader(dRed, dGreen, dBlue);
    } else {
        // Initialised so that a short line yields zeros, not garbage.
        double sRed = 0.0, sGreen = 0.0, sBlue = 0.0, phongExp = 0.0;
        double mirror = 0.0;
        char m = 'z';

        iss >> sRed >> sGreen >> sBlue >> phongExp;

        // The mirror term is optional and only valid as "m <value>".
        if (!(iss >> m >> mirror) || m != 'm') {
            mirror = 0.0;
        }

        shader = new BlinnPhongShader(dRed, dGreen, dBlue, sRed, sGreen, sBlue, phongExp, mirror);
    }

    Surface *sphere = new Sphere(x, y, z, radius, shader);
    surfaces.push_back(sphere);
}

/**
 * Parses one triangle line of the scene file and appends the triangle to
 * surfaces.
 *
 * Whitespace-separated fields, in order:
 *   keyword, shape name, x1 y1 z1, x2 y2 z2, x3 y3 z3, shader name,
 *   diffuse r g b, then for non-lambertian shaders: specular r g b,
 *   phong exponent, and optionally the letter 'm' followed by a mirror
 *   coefficient.
 *
 * @param s  the full text of the line
 *
 * Every field up to the diffuse colour is mandatory; if any of them cannot
 * be read an error is printed and the program exits with status 1, rather
 * than constructing a triangle from indeterminate values. Specular fields
 * that are missing default to 0, and the mirror coefficient is 0 unless it
 * is introduced by 'm' and parses as a number.
 */
void RayTracer::processTriangle(string s) {
    double x1 = 0.0, y1 = 0.0, z1 = 0.0;
    double x2 = 0.0, y2 = 0.0, z2 = 0.0;
    double x3 = 0.0, y3 = 0.0, z3 = 0.0;
    double dRed = 0.0, dGreen = 0.0, dBlue = 0.0;
    string keyword, shapeName, shaderString;
    istringstream iss(s, istringstream::in);

    iss >> keyword >> shapeName
            >> x1 >> y1 >> z1
            >> x2 >> y2 >> z2
            >> x3 >> y3 >> z3
            >> shaderString
            >> dRed >> dGreen >> dBlue;

    if (!iss) {
        cout << "Error: RayTracer: processTriangle - malformed triangle line: " << s << endl;
        exit(1);
    }

    Shader* shader = NULL;
    if (shaderString.compare(lambertian) == 0) {
        shader = new LambertianShader(dRed, dGreen, dBlue);
    } else {
        // Initialised so that a short line yields zeros, not garbage.
        double sRed = 0.0, sGreen = 0.0, sBlue = 0.0, phongExp = 0.0;
        double mirror = 0.0;
        char m = 'z';

        iss >> sRed >> sGreen >> sBlue >> phongExp;

        // The mirror term is optional and only valid as "m <value>".
        if (!(iss >> m >> mirror) || m != 'm') {
            mirror = 0.0;
        }

        shader = new BlinnPhongShader(dRed, dGreen, dBlue, sRed, sGreen, sBlue, phongExp, mirror);
    }

    Surface *triangle = new Triangle(x1, y1, z1, x2, y2, z2, x3, y3, z3, shader);
    surfaces.push_back(triangle);
}

/**
 * Parses one camera line of the scene file and appends the camera to cameras.
 *
 * Whitespace-separated fields, in order: keyword, x y z, u v w,
 * focal length, image plane width. The eight numbers are passed to the
 * Camera constructor in that order.
 *
 * @param s  the full text of the line
 *
 * All eight numbers are mandatory; if any cannot be read an error is printed
 * and the program exits with status 1, rather than constructing a camera
 * from indeterminate values.
 */
void RayTracer::processCamera(string s) {
    double x = 0.0, y = 0.0, z = 0.0;
    double u = 0.0, v = 0.0, w = 0.0;
    double focalLength = 0.0, imagePW = 0.0;
    string keyword;
    istringstream iss(s, istringstream::in);

    iss >> keyword
            >> x >> y >> z
            >> u >> v >> w
            >> focalLength >> imagePW;

    if (!iss) {
        cout << "Error: RayTracer: processCamera - malformed camera line: " << s << endl;
        exit(1);
    }

    Camera cam(x, y, z, u, v, w, focalLength, imagePW);
    cameras.push_back(cam);
}

/**
 * Parses one light line of the scene file and appends the light to lights.
 *
 * Whitespace-separated fields, in order: keyword, x y z, r g b.
 * Each colour component above 1.0 is reduced to 1.0 before the Light is
 * constructed; values are not clamped from below.
 *
 * @param s  the full text of the line
 *
 * All six numbers are mandatory; if any cannot be read an error is printed
 * and the program exits with status 1, rather than constructing a light
 * from indeterminate values.
 */
void RayTracer::processLight(string s) {
    double x = 0.0, y = 0.0, z = 0.0;
    double r = 0.0, g = 0.0, b = 0.0;
    string keyword;
    istringstream iss(s, istringstream::in);

    iss >> keyword >> x >> y >> z >> r >> g >> b;

    if (!iss) {
        cout << "Error: RayTracer: processLight - malformed light line: " << s << endl;
        exit(1);
    }

    // Cap each colour component at full intensity.
    if (r > 1.0) r = 1.0;
    if (g > 1.0) g = 1.0;
    if (b > 1.0) b = 1.0;

    Light l(x, y, z, r, g, b);
    lights.push_back(l);
}

